#!/bin/bash

# The SBATCH directives must appear before any executable line in this script.

#SBATCH --qos high2         # QOS (priority).
#SBATCH -N 1               # Number of nodes requested.
#SBATCH -n 1               # Number of tasks (i.e. processes).
#SBATCH --cpus-per-task=1  # Number of cores per task.
#SBATCH --gres=gpu:8       # Number of GPUs.
#SBATCH -t 60-00:00          # Time requested (D-HH:MM).
#SBATCH --nodelist=em5    # Pin the job to this specific machine (comment out to allow any node).

# Have Slurm cd to this directory before running the script.
# Comment the directive out to run from the directory you submit the job from.
#SBATCH -D /home/code/test_time_training/ttt_mae

# To control the output files, remove one '#' from the "##SBATCH" lines below.
# By default stdout and stderr go to the same place, but if you enable both
# directives below they'll be split up.
# %N is the hostname (if used, will create output(s) per node).
# %j is jobid.
##SBATCH -o slurm.%N.%j.out    # STDOUT
##SBATCH -e slurm.%N.%j.err    # STDERR

# Prepare the runtime environment for the training job.
#
# Source the user's shell configuration; presumably this is what makes the
# `conda` command available in the batch shell -- TODO confirm.
# NOTE(review): strict mode (set -euo pipefail) is deliberately not enabled,
# since sourced rc files / conda hooks are outside this script's control.
source ~/.bashrc

# Stop right away if the environment cannot be activated, instead of running
# the training with whatever python happens to be on PATH.
conda activate taming || {
  echo "error: could not activate conda environment 'taming'" >&2
  exit 1
}

# main_pretrain.py is referenced by a relative path below, so the job must run
# from the project directory.
cd /home/code/test_time_training/ttt_mae || {
  echo "error: cannot cd to /home/code/test_time_training/ttt_mae" >&2
  exit 1
}

# Print GPU status for context in the job log.
nvidia-smi
# Python will buffer the output of your script unless you set this.
# If you're not using python, figure out how to turn off output
# buffering when stdout is a file, or else when watching your output
# file you'll only get updates every several lines printed.
export PYTHONUNBUFFERED=1

# Directory passed to the trainer as --output_dir; the --dist_url rendezvous
# file is also placed inside it.
OUTPUT_DIR='/home/code/test_time_training/ttt_mae/output_dir/shared_training_small_no_bs'
# -p creates missing parent directories and does not fail when the directory
# already exists (e.g. when the job is re-submitted). Abort if it cannot be
# created, because everything below writes into it.
mkdir -p -- "${OUTPUT_DIR}" || {
  echo "error: cannot create output directory: ${OUTPUT_DIR}" >&2
  exit 1
}
# Dataset root passed to the trainer as --data_path.
DATA_PATH='/home/group/ilsvrc/'
# Epoch time in milliseconds (%3N is a GNU date extension). Used as the file
# name in --dist_url, presumably so that each run gets its own rendezvous file.
TIME=$(date +%s%3N)
# Launch pre-training through torch.distributed.launch with 8 processes on
# this node. All expansions are quoted so the paths reach the trainer as
# single arguments.
python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \
        --data_path "${DATA_PATH}" \
        --model mae_vit_small_patch16 \
        --input_size 224 \
        --batch_size 128 \
        --mask_ratio 0.75 \
        --warmup_epochs 5 \
        --epochs 90 \
        --blr 1e-4 \
        --save_ckpt_freq 30 \
        --output_dir "${OUTPUT_DIR}" \
        --dist_url "file://${OUTPUT_DIR}/${TIME}" \
        --shared_training
# Remember the trainer's exit status. Without this the script would end with
# the status of the final `date`, so a failed training run would still make
# the batch script exit with 0.
train_status=$?


# Alternative invocation (disabled): a single process restricted to GPU 0.
# CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 main_pretrain.py \
#         --data_path ${DATA_PATH} \
#         --model mae_vit_small_patch16 \
#         --input_size 32 \
#         --batch_size 128 \
#         --mask_ratio 0.75 \
#         --warmup_epochs 5 \
#         --epochs 90 \
#         --blr 1e-4 \
#         --save_ckpt_freq 50 \
#         --output_dir ${OUTPUT_DIR}  \
#         --shared_training

# Print completion time.
date

# Exit with the trainer's status so the batch script reports success or failure.
exit "${train_status}"